Skip to content

Conversation

@Red-Portal
Copy link
Member

@Red-Portal Red-Portal commented Nov 21, 2025

Consider the case where we would like to approximate a constrained target distribution with density $\pi : \mathcal{X} \to \mathbb{R}{> 0}$ with an unconstrained variational approximation with density $q : \mathbb{R}^d \to \mathbb{R}{> 0}$. The canonical way to deal with this, popularized by the ADVI paper1, is to use a $b$ bijective transformation ("Bijectors") $b : \mathbb{R}^d \to \mathcal{X}$ such that $q$ is augmented into $q_{b}$ as

$$q_{b^{-1}}(z) = q(b^{-1}(z)) {\lvert \mathrm{J}_{b^{-1}}(z) \rvert}$$

Then AdvancedVI needs to solve the problem

$$q_{b^{-1}}^* = \arg\min_{q \in \mathcal{Q}} \;\; \mathrm{D}(q_{b^{-1}}, \pi_b) .$$

But notice that the optimization is, in reality, over $q$. Therefore, often times, AdvancedVI needs access to the underlying q. I will refer to this as the "primal" scheme.

Previously, this was done by giving a special treatment to q <: Bijectors.TransformedDistribution through the Bijectors extension. In particular, the Bijectors extension had to add a specialization to a lot of methods that simply unwrap a TransformedDistribution to do something. This behavior is difficult to document and, therefore, wasn't fully explained in the documentation. Furthermore, each of the relevant methods needs to be specialized in the Bijectors extension, which resulted in a multiplicative complexity, especially for unit testing.

This, however, is unnecessary. Instead, there exists an equivalent "dual" problem that operates in unconstrained space by approximating the transformed posterior

$$\pi_b(\eta) = \pi(b^{-1}(\eta)) {\lvert \mathrm{J}_{b^{-1}}(\eta) \rvert} .$$

That is, we can solve the problem

$$q^* = \arg\min_{q \in \mathcal{Q}} \;\; \mathrm{D}(q, \pi_b)$$

and then post-process the output to retrieve $q_{b^{-1}}^*$.

Within this context, this PR removes the Bijectors extension to fix this problem. Here are the reationals:

  • As mentioned above, AdvancedVI doesn't need to implement the primal scheme. In fact, the upcoming interface in Turing is planned to implement the dual scheme above.
  • The new algorithms KLMinNaturalGradDescent, KLMinWassFwdBwd, FisherMinBatchMatch, for example, do not work in constrained support at all, so they can only be used via the dual scheme. So the way that KLMinRepGradDescent and friends implemented the primal scheme is a bit redundant in terms of consistency at this point.

Instead, a tutorial has been added to the documentation on how to use VI with constrained supports via the dual scheme.

Footnotes

  1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research, 18(14), 1-45.


include("normallognormal.jl")
include("unconstrdist.jl")
struct Dist{D<:ContinuousMultivariateDistribution}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The content of unconstrdist.jl have been moved here.

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 8c657e0 Previous: f9d7f0b Ratio
normal/RepGradELBO + STL/meanfield/Zygote 2215502796 ns 2640111495.5 ns 0.84
normal/RepGradELBO + STL/meanfield/ReverseDiff 571883445 ns 611337145 ns 0.94
normal/RepGradELBO + STL/meanfield/Mooncake 196000337.5 ns 246507667 ns 0.80
normal/RepGradELBO + STL/fullrank/Zygote 1708922306 ns 2061469121 ns 0.83
normal/RepGradELBO + STL/fullrank/ReverseDiff 1103359376 ns 1166722255 ns 0.95
normal/RepGradELBO + STL/fullrank/Mooncake 485775019 ns 681311257.5 ns 0.71
normal/RepGradELBO/meanfield/Zygote 1269167076.5 ns 1592865651 ns 0.80
normal/RepGradELBO/meanfield/ReverseDiff 276315120 ns 304657276 ns 0.91
normal/RepGradELBO/meanfield/Mooncake 144033159.5 ns 174167780 ns 0.83
normal/RepGradELBO/fullrank/Zygote 832644703 ns 1116928390 ns 0.75
normal/RepGradELBO/fullrank/ReverseDiff 533858360 ns 605576908 ns 0.88
normal/RepGradELBO/fullrank/Mooncake 403687069 ns 563720905 ns 0.72

This comment was automatically generated by workflow using github-action-benchmark.

Red-Portal and others added 10 commits November 22, 2025 11:53
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@github-actions
Copy link
Contributor

AdvancedVI.jl documentation for PR #219 is available at:
https://TuringLang.github.io/AdvancedVI.jl/previews/PR219/

Red-Portal and others added 2 commits November 22, 2025 15:09
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@Red-Portal
Copy link
Member Author

The updates to the documentation and README have been suppressed for clarity and will be added later once the PR is approved.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants